import math
import time
from typing import Union
import argparse

import numpy as np
import torch
import torch.nn as nn

from einops import rearrange
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from hypnettorch.mnets import MLP
from hypnettorch.hnets import HMLP

from env import PushTEnv, NormalizeActionWrapper, SpaceConversionWrapper, StateStackWrapper



# Hidden-layer widths of the HMLP hypernetwork decoder, keyed by the size tag
# passed as ``vae_decoder_size``. Each entry is two equal-width hidden layers;
# the comprehension builds a fresh list per key.
DECODER_SIZE_DICT = {
    size_tag: [width, width]
    for size_tag, width in (("xs", 50), ("s", 100), ("m", 200), ("l", 400))
}

def vae_loss(x, x_recon, z_mean, z_logvar, kl_coeff = 1e-6):
    """Compute the two VAE loss terms.

    Args:
        x: reconstruction target.
        x_recon: reconstruction, same shape as ``x``.
        z_mean: latent Gaussian means.
        z_logvar: latent Gaussian log-variances, same shape as ``z_mean``.
        kl_coeff: multiplier applied to the KL term.

    Returns:
        dict with ``"recon_loss"`` (mean squared error between ``x_recon``
        and ``x``) and ``"kl_loss"`` (``kl_coeff`` times the KL divergence to
        a standard normal, summed over every element, batch included).
    """
    reconstruction = nn.MSELoss()(x_recon, x)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element, then summed.
    kl_terms = 1 + z_logvar - z_mean.pow(2) - z_logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return {"recon_loss": reconstruction, "kl_loss": kl_coeff * kl_divergence}

class VAEHyperNetModel(nn.Module):
    """VAE over trajectories whose decoder is a hypernetwork emitting policy weights.

    The encoder maps a flattened (state, action) trajectory to the mean and
    log-variance of a diagonal Gaussian latent ``z``. The decoder is a
    hypnettorch ``HMLP`` that maps ``z`` to a full set of parameters for
    ``policy``, using the shapes in ``policy.param_shapes``.
    """

    def __init__(
            self, 
            policy, 
            state_dim, 
            action_dim, 
            latent_dim,
            vae_decoder_size, 
            traj_len, 
            stochastic_decoder = False, 
            kl_coeff = 1e-6
        ):
        """Build the encoder MLP and the hypernetwork decoder.

        Args:
            policy: target network. It is presumably a hypnettorch MLP built
                with ``no_weights=True``; only its ``param_shapes`` are read here.
            state_dim: per-step state size.
            action_dim: per-step action size.
            latent_dim: size of the latent code ``z``.
            vae_decoder_size: key of ``DECODER_SIZE_DICT`` selecting the HMLP
                hidden-layer widths.
            traj_len: number of timesteps the encoder expects per trajectory.
            stochastic_decoder: if True, ``forward`` perturbs the generated
                weights with Gaussian noise whose log-variance is a single
                learned scalar.
            kl_coeff: multiplier applied to the KL term by ``vae_loss``.
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.latent_dim = latent_dim
        self.vae_decoder_size = vae_decoder_size
        self.traj_len = traj_len
        self.stochastic_decoder = stochastic_decoder
        self.kl_coeff = kl_coeff
        desired_shape = policy.param_shapes
        self.hnet = HMLP(desired_shape, cond_in_size=0, uncond_in_size=latent_dim, layers=DECODER_SIZE_DICT[self.vae_decoder_size])
        if self.stochastic_decoder:
            # One learned log-variance shared by every generated parameter.
            self.log_var = nn.Parameter(torch.zeros(1))
        # Not used by this class's methods; kept for backward compatibility.
        self.criterion = nn.MSELoss()

        # Encoder: flattened trajectory -> [mu, log_var], each latent_dim wide.
        self.encoder = nn.Sequential(
            nn.Linear(self.traj_len * (self.state_dim + self.action_dim), 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, self.latent_dim * 2)  # output both mean and log variance
        )

    def encode(self, x):
        """Encode trajectories into latent Gaussian parameters.

        Args:
            x: tensor of shape (B, T, state_dim + action_dim) with
                T == traj_len, as required by the first encoder layer.

        Returns:
            (mu, log_var), each of shape (B, latent_dim).
        """
        x_flattened = rearrange(x, "b t sa -> b (t sa)")
        encoded = self.encoder(x_flattened)
        mu, log_var = torch.chunk(encoded, 2, dim=-1)  # split the encoder output into mu and log_var components
        return mu, log_var

    def decode(self, x):
        """Map latent codes to policy weights using HMLP's default return format."""
        weights = self.hnet.forward(uncond_input=x)
        return weights

    def forward(self, batch_obs, policy, batch_act):
        """Encode (obs, act) trajectories, generate per-sample weights, reconstruct actions.

        Args:
            batch_obs: observations; must concatenate with ``batch_act`` on the
                last axis to form (B, T, state_dim + action_dim).
            policy: network called as ``policy(obs, weights)`` with generated weights.
            batch_act: target actions with batch size B on axis 0.

        Returns:
            (reconstructed_actions, loss_dict), where ``loss_dict`` comes from ``vae_loss``.
        """
        batch_size = batch_act.shape[0]
        x = torch.cat([batch_obs, batch_act], -1)

        mu, log_var = self.encode(x)

        z = reparameterize(mu, log_var)

        # 'sequential' always yields one list of parameter tensors per sample.
        # The default 'squeezed' format drops the batch level when B == 1, in
        # which case weights_mean[i] would index a single layer tensor.
        weights_mean = self.hnet.forward(uncond_input=z, ret_format='sequential')
        sampled_actions = []
        for i in range(batch_size):
            weight_mean_ = weights_mean[i]
            if self.stochastic_decoder:
                std = torch.exp(0.5 * self.log_var)
                # Draw epsilon per element (shaped like the layer) so each weight
                # is an independent sample from N(mean, exp(log_var)). Drawing it
                # shaped like `std` (1,) would shift a whole layer by one scalar.
                sampled_weights = [
                    layer_mean + torch.randn_like(layer_mean) * std
                    for layer_mean in weight_mean_
                ]
            else:
                sampled_weights = weight_mean_

            # get actions from policy
            sampled_actions.append(policy(batch_obs[i], sampled_weights))

        reconstructed_actions = torch.stack(sampled_actions)

        loss = vae_loss(batch_act, reconstructed_actions, mu, log_var, self.kl_coeff)

        return reconstructed_actions, loss

    def get_policy_weights(self, device, data = None):
        """Generate policy weights from the prior or from encoded trajectories.

        Args:
            device: device on which the latent ``z`` is placed.
            data: optional dict with ``"states"`` and ``"actions"`` tensors that
                concatenate on the last axis to (B, T, state_dim + action_dim).
                If None, one latent is drawn from the standard-normal prior.

        Returns:
            The HMLP output for ``z`` in HMLP's default return format.
        """
        if data is None:
            # z ~ N(0, I), drawn from the CPU generator and then moved to `device`.
            z = torch.randn(1, self.latent_dim).to(device)
        else:
            x = torch.cat([data["states"], data["actions"]], -1).to(device)
            mu, log_var = self.encode(x)
            z = reparameterize(mu, log_var)  # (B, latent_dim)

        return self.hnet.forward(uncond_input=z)



def reparameterize(mu, log_var):
    """Reparameterization trick: return ``mu + sigma * eps`` with eps ~ N(0, I).

    ``sigma = exp(log_var / 2)``. ``eps`` has the shape, dtype and device of
    ``log_var``, so gradients flow through ``mu`` and ``log_var``.
    """
    sigma = log_var.mul(0.5).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma



class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions (e.g. diffusion timesteps).

    For a 1-D input of shape (B,), returns (B, 2 * (dim // 2)): sines in the
    first half and cosines in the second, at frequencies spaced geometrically
    from 1 down to 1/10000. ``dim`` must be at least 4, otherwise the
    frequency step divides by zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -log_step)
        phases = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)


class Downsample1d(nn.Module):
    """Stride-2 temporal downsampling (length L -> ceil(L / 2)), channels unchanged."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

class Upsample1d(nn.Module):
    """Transposed-conv temporal upsampling (length L -> 2L), channels unchanged."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish

        The convolution uses padding kernel_size // 2, so odd kernel sizes
        keep the temporal length unchanged.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        y = self.block(x)
        return y


class ConditionalResidualBlock1D(nn.Module):
    """Residual pair of Conv1dBlocks with FiLM conditioning in between.

    The conditioning vector is projected to a per-channel scale and bias
    (FiLM, https://arxiv.org/abs/1709.07871). These modulate the output of
    the first conv block before the second one. A 1x1 conv projects the skip
    path when the channel counts differ.
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        self.out_channels = out_channels
        # Produces 2 * out_channels values per sample: [scale | bias].
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, 2 * out_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        first_block, second_block = self.blocks
        hidden = first_block(x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        gamma, beta = film[:, 0], film[:, 1]

        hidden = second_block(gamma * hidden + beta)
        return hidden + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1-D temporal U-Net over (B, T, input_dim) sequences with FiLM conditioning.

    Every residual block is conditioned on the concatenation of a learned
    diffusion-step embedding and an optional global conditioning vector.
    """

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Channel dim of the denoised sequence (e.g. dim of actions).
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level (any sequence).
          The length of this sequence determines the number of levels.
          The default is a tuple so no mutable default is shared across calls.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # Each residual block is conditioned on [step embedding, global_cond].
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # The deepest level does not downsample.
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE: reversed(in_out[1:]) has len(in_out) - 1 entries, so `ind`
            # never reaches len(in_out) - 1 and every up level gets an
            # Upsample1d. That matches the len(in_out) - 1 downsamples above.
            # Kept as-is because the condition fixes the saved parameter layout.
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or scalar tensor, or a Python number; diffusion step
        global_cond: (B,global_cond_dim) or None
        output: (B,T,input_dim)
        """
        # (B,T,C) -> (B,C,T) for Conv1d
        sample = sample.moveaxis(-1,-2)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif timesteps.ndim == 0:
            timesteps = timesteps[None]
        # Ensure (B,) timestep tensors also live on the sample's device.
        timesteps = timesteps.to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], dim=-1)

        x = sample
        skips = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            skips.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            # U-Net skip connection: concatenate the matching down-path features.
            x = torch.cat((x, skips.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T) -> (B,T,C)
        return x.moveaxis(-1,-2)
    


def _denoise_latent(noise_pred_net, noise_scheduler, obs_cond, batch_size,
                    latent_dim, num_diffusion_iters, device):
    """Run the full DDPM reverse chain and return denoised latents.

    The U-Net treats each latent as a length-``latent_dim`` sequence with one
    channel, so the sample is shaped (B, latent_dim, 1) while denoising.

    Returns:
        Tensor of shape (B, latent_dim), still in the diffusion model's scaled space.
    """
    sample = torch.randn((batch_size, latent_dim), device=device)
    sample = rearrange(sample, "b (t s) -> b t s", t=latent_dim)
    noise_scheduler.set_timesteps(num_diffusion_iters)
    with torch.no_grad():
        for k in noise_scheduler.timesteps:
            noise_pred = noise_pred_net(
                sample=sample,
                timestep=k,
                global_cond=obs_cond
            )
            sample = noise_scheduler.step(
                model_output=noise_pred,
                timestep=k,
                sample=sample
            ).prev_sample
    return rearrange(sample, "b t s -> b (t s)")


def main(args):
    """Roll out a diffusion-sampled hypernetwork policy in PushT and render it.

    Loads a VAE hypernetwork from ``vae.pt`` and a latent diffusion U-Net from
    ``diff.pt``. Every ``action_horizon`` env steps it does three things:
      * samples a latent conditioned on the current observation and a task id;
      * decodes the latent into policy weights;
      * steps the env with that policy.
    ``args.perturbation`` is set on the env as ``move_t_range``, and
    ``iter_indices_to_move_t`` is set to randomly chosen step indices.

    Args:
        args: namespace with integer ``seed`` and ``perturbation`` attributes.
    """
    vae_ckpt_path = "vae.pt"
    diffusion_ckpt_path = "diff.pt"

    # set seed (numpy also drives the perturbation-step choice below)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    obs_dim = 10
    action_dim = 2
    latent_dim = 256

    policy = MLP(
        n_in=obs_dim,
        n_out=action_dim,
        hidden_layers=[256, 256],
        no_weights=True
    )
    vae = VAEHyperNetModel(
        policy,
        state_dim=obs_dim,
        action_dim=action_dim,
        latent_dim=latent_dim,
        vae_decoder_size='l',
        traj_len=16,
    ).to(device)

    noise_pred_net = ConditionalUnet1D(
        input_dim=1,
        global_cond_dim=obs_dim + 1,  # observation + task id
        diffusion_step_embed_dim=256,
        down_dims=[32, 64, 128],
        kernel_size=5
    )

    num_diffusion_iters = 100
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=num_diffusion_iters,
        # the choice of beta schedule has big impact on performance
        # we found squared cosine works the best
        beta_schedule='squaredcos_cap_v2',
        # clip output to [-1,1] to improve stability
        clip_sample=True,
        # our network predicts noise (instead of denoised action)
        prediction_type='epsilon'
    )

    # map_location lets checkpoints saved on a GPU load on CPU-only machines.
    checkpoint = torch.load(vae_ckpt_path, map_location=device)
    vae.load_state_dict(checkpoint['model_state_dict'])
    for param in vae.parameters():
        param.requires_grad = False
    diffusion_checkpoint = torch.load(diffusion_ckpt_path, map_location=device)
    noise_pred_net.load_state_dict(diffusion_checkpoint['state_dict'])

    vae = vae.to(device).eval()
    noise_pred_net = noise_pred_net.to(device).eval()

    stats = {
        "obs": {
            "max": np.array([496.14618, 510.9579, 439.9153, 485.6641, 6.2830877], dtype=np.float32),
            "min": np.array([13.456424, 32.938293, 57.471767, 108.27995, 0.00021559125], dtype=np.float32),
        },
        "action": {
            "max": np.array([511.0, 511.0], dtype=np.float32),
            "min": np.array([12.0, 25.0], dtype=np.float32),
        },
    }

    env = PushTEnv(
        seed=args.seed,
    )
    env = NormalizeActionWrapper(env, stats)
    env = SpaceConversionWrapper(env)
    env = StateStackWrapper(env, 2)

    obs = env.reset()[0][None]
    B = 1
    number_of_perturbations = 10
    noise_scale = args.perturbation
    latent_scaling_factor = 0.18215
    # Float so it concatenates with the float32 observation without type promotion.
    task_id = torch.tensor([[0.0]], device=device)
    sleep_dt = 0.05
    action_horizon = 16
    step_idx = 0
    max_traj_len = 256

    state = torch.as_tensor(obs, dtype=torch.float32, device=device)
    done = np.zeros(B, dtype=bool)
    # choose distinct step indices (per env) at which to apply a perturbation
    steps_to_perturb = [
        np.random.choice(max_traj_len, number_of_perturbations, replace=False)
        for _ in range(B)
    ]
    env.setattr("move_t_range", noise_scale)
    env.setattr("iter_indices_to_move_t", steps_to_perturb[0])

    try:
        while not np.all(done):
            obs_cond = torch.tensor(obs, dtype=torch.float32, device=device)
            obs_cond = torch.cat([obs_cond, task_id], dim=1)
            latent = _denoise_latent(
                noise_pred_net, noise_scheduler, obs_cond,
                B, latent_dim, num_diffusion_iters, device,
            )
            with torch.no_grad():
                policy_weights = vae.decode(latent / latent_scaling_factor)

            time.sleep(10 * sleep_dt)

            next_obs = obs
            for _ in range(action_horizon):
                with torch.no_grad():
                    if state.shape[0] == 1:
                        # presumably hypnettorch's default 'squeezed' format
                        # returns one flat weight list when the batch is 1
                        pred_acts = policy(state[0], policy_weights)
                    else:
                        pred_acts = torch.stack([
                            policy(state[i], policy_weights[i]) for i in range(B)
                        ])
                pred_acts = pred_acts.cpu().numpy()

                next_obs, _reward, terminated, truncated, _info = env.step(pred_acts)

                env.render()
                time.sleep(sleep_dt)

                next_obs = np.asarray(next_obs)[None]
                state = torch.as_tensor(next_obs, dtype=torch.float32, device=device)

                # An episode ends on termination OR truncation.
                done = np.logical_or(done, np.logical_or(terminated, truncated))

                step_idx += 1
                if step_idx >= max_traj_len:
                    done[:] = True
                if np.all(done):
                    break
            obs = next_obs
    finally:
        env.close()

if __name__ == "__main__":
    # Command-line entry point: parse the perturbation level and seed, then run the rollout.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--perturbation",
        type=int,
        default=0,
        help="Perturbation amount, can be 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Seed for the environment",
    )
    args = cli.parse_args()
    main(args)